Make a decision tree from the iris data

Taken from Google's Visualizing a Decision Tree - Machine Learning Recipes #2


In [1]:
from sklearn import tree
from sklearn.datasets import load_iris

In [2]:
iris = load_iris()

In [3]:
type(iris)


Out[3]:
sklearn.datasets.base.Bunch

In [4]:
isinstance(iris, dict)


Out[4]:
True

In [5]:
iris.keys()


Out[5]:
dict_keys(['target_names', 'target', 'DESCR', 'feature_names', 'data'])
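
The Bunch is a dict subclass that also exposes each key as an attribute, so dict-style and attribute-style access are interchangeable (a quick check, not part of the original recipe):

iris['feature_names'] == iris.feature_names  # -> True, both styles return the same list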

In [6]:
iris.feature_names


Out[6]:
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [7]:
iris.target_names


Out[7]:
array(['setosa', 'versicolor', 'virginica'], 
      dtype='<U10')

In [8]:
iris['target']


Out[8]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [9]:
type(iris['target'])


Out[9]:
numpy.ndarray

In [10]:
# print the first five examples: label and feature vector
for i in range(5):
    print('Example {}: label {}, features {}'.format(i, iris.target[i], iris.data[i]))


Example 0: label 0, features [ 5.1  3.5  1.4  0.2]
Example 1: label 0, features [ 4.9  3.   1.4  0.2]
Example 2: label 0, features [ 4.7  3.2  1.3  0.2]
Example 3: label 0, features [ 4.6  3.1  1.5  0.2]
Example 4: label 0, features [ 5.   3.6  1.4  0.2]

In [11]:
import numpy as np

In [12]:
test_idx = [0, 50, 100]  # rows to withhold from the training data and use for testing

# remove the same rows from the feature data
# Note: without axis=0, np.delete() flattens the array and returns a single
# list of values instead of a list of rows
# i.e. we want this:
# [[ 4.9,  3. ,  1.4,  0.2],
#  [ 4.7,  3.2,  1.3,  0.2],
#  [ 4.6,  3.1,  1.5,  0.2], …]
# and not this:
# [ 4.9,  3. ,  1.4,  0.2,  4.7,  3.2,  1.3,  0.2,  4.6,  3.1,  1.5,  0.2, …]
train_data = np.delete(iris.data, test_idx, axis=0)

# np.delete() removes the same three indices from iris.target
# Note: here the axis= argument doesn't matter, since iris.target is 1-D
# (one integer label per row)
train_target = np.delete(iris.target, test_idx)
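
To see the flattening behaviour described in the comment above, here is a small sketch on a made-up 2x3 array:

demo = np.array([[1, 2, 3],
                 [4, 5, 6]])
np.delete(demo, [0], axis=0)  # drops row 0 -> array([[4, 5, 6]])
np.delete(demo, [0])          # no axis: flattens first -> array([2, 3, 4, 5, 6])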

In [13]:
# see how the rows have been removed (the original iris.target is untouched)
len(iris.target)


Out[13]:
150

In [14]:
len(train_data)  # the three test rows have been taken out


Out[14]:
147

In [15]:
len(train_target)


Out[15]:
147

In [16]:
test_target = iris.target[test_idx]

In [17]:
test_target  # only three


Out[17]:
array([0, 1, 2])

In [18]:
test_data = iris.data[test_idx]

In [19]:
test_data


Out[19]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.3,  3.3,  6. ,  2.5]])

In [20]:
# Note: fancy indexing on a numpy array
l = [1, 4, 5, 6, 8, 999, 44, 6, 7, 10]
a = np.array(l)

# now we can pull out items with a list of indices/rows
idx = [0, 4, 6]
a[idx]


Out[20]:
array([ 1,  8, 44])
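
Hand-picking one row per class, as above, mirrors what scikit-learn's train_test_split helper does with random rows. A hedged sketch of that alternative (not part of the original recipe; the import path is sklearn.model_selection in newer releases, sklearn.cross_validation in older ones):

from sklearn.model_selection import train_test_split

# an integer test_size is interpreted as a number of samples;
# unlike the manual split above, the three test rows are chosen at random
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=3, random_state=0)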

In [21]:
# train model
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)


Out[21]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')
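
The repr above shows the classifier's defaults. Ties between equally good splits are broken using random_state, so pinning it gives a reproducible tree (an optional tweak, not in the original recipe; clf_repro is just an illustrative name):

clf_repro = tree.DecisionTreeClassifier(random_state=0)
clf_repro.fit(train_data, train_target)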

In [22]:
# make prediction
clf.predict(test_data)


Out[22]:
array([0, 1, 2])

In [23]:
# do the predictions match the withheld labels?
clf.predict(test_data) == test_target


Out[23]:
array([ True,  True,  True], dtype=bool)
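
The element-wise comparison works fine for three rows; for a larger test set the built-in mean-accuracy score is handier (equivalent here, a sketch):

clf.score(test_data, test_target)                # mean accuracy -> 1.0
(clf.predict(test_data) == test_target).mean()   # same number via numpy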

Write out the decision tree as a graphic


In [25]:
from sklearn.externals.six import StringIO
import pydotplus  # note: installed pydotplus for Python 3 compatibility

In [26]:
dot_data = StringIO()

tree.export_graphviz(clf, 
                     out_file=dot_data, 
                     feature_names=iris.feature_names, 
                     class_names=iris.target_names, 
                     filled=True, 
                     rounded=True, 
                     impurity=False)

In [27]:
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

In [29]:
# Graphviz installed on macOS with `brew install graphviz`
graph.write_pdf('iris.pdf')

# open -a preview ~/ipython/tensorflow/iris.pdf


Out[29]:
True
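
Instead of writing a PDF and opening it externally, the same graph can be rendered inline in the notebook (assuming Graphviz is installed as noted above):

from IPython.display import Image
Image(graph.create_png())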

In [32]:
# now check the rows withheld for testing
# against the split rules in the graphic tree
test_data[0], test_target[0]  # we know this is a setosa


Out[32]:
(array([ 5.1,  3.5,  1.4,  0.2]), 0)

In [34]:
iris.feature_names, iris.target_names


Out[34]:
(['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 array(['setosa', 'versicolor', 'virginica'], 
       dtype='<U10'))
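
To follow a test row through the split rules shown in iris.pdf, it helps to pair the feature names with the values (a small convenience sketch, not in the original recipe):

for name, value in zip(iris.feature_names, test_data[0]):
    print('{}: {}'.format(name, value))
# compare these values against the thresholds shown in the exported tree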

In [36]:
test_data[1], test_target[1]  # we know this is a versicolor


Out[36]:
(array([ 7. ,  3.2,  4.7,  1.4]), 1)

In [38]:
test_data[2], test_target[2]  # we know this is a virginica

# all three test rows match!


Out[38]:
(array([ 6.3,  3.3,  6. ,  2.5]), 2)